function savings_EWM_quadrant = generate_results_quadrant_rule(bsimax,sample,wave,tcost,covariate,winter_prevuse,SMC,fullsample,savedir,fbase)
% GENERATE_RESULTS_QUADRANT_RULE  Cost savings or energy reduction from the
% EWM quadrant rule, found by grid search, with a bootstrap confidence
% interval.
%
% A quadrant rule treats a (binned) household when both
%   sign1*(x1 - t1) >= 0   and   sign2*(x2 - t2) >= 0,
% with thresholds t1, t2 taken from the histogram bin edges and signs in
% {-1,+1}. The rule minimizing the summed IPW score g is selected.
%
% Inputs:
% (1) bsimax: number of bootstrap runs (<= 0 skips the bootstrap, in which
%     case the confidence interval collapses to the point estimate)
% (2) sample: cross-sectional dataset with electricity consumption,
%     covariates, and propensity scores
% (3) wave: wave-specific output (=3, 6, or 7) or pooled sample (=0)
% (4) tcost: private marginal cost of implementing the program per
%     household; >0 gives cost savings, =0 gives kWh reduction; negative
%     values are rejected
% (5) covariate: 'income', 'size', 'vintage', 'min', 'max', or 'std'
%     (min/max/std of baseline consumption). Column 1 of X is this
%     covariate; column 2 is (mean) pre-treatment consumption.
% (6) winter_prevuse: 1 if baseline consumption is the mean over winter
%     months (Jan and Feb), 0 if it is the mean over the specified
%     pre-treatment periods. Within this function it only affects the name
%     of the saved rule file.
% (7) SMC: 1 = use social marginal cost; 0 = use retail electricity price
% (8) fullsample: 1 if full sample, 0 if estimation sample
% (9) savedir: directory in which the rule parameters (.mat) are saved
% (10) fbase: string indicating the propensity score and baseline months
%      specification; used as the prefix of the saved file name
%
% Output:
% (1) savings_EWM_quadrant: 2-row table. Row 1 contains the covariate, the
%     share of households to treat (%), the per-household savings point
%     estimate, and total savings for the entire state. Row 2 contains the
%     95% confidence interval for the point estimate.
%
% Side effects: reseeds the global RNG with rng(0); writes one .mat per
% bootstrap draw under .\intermediate_data\EWM_results\BS\quadrant\<batch>\;
% writes the rule parameters to savedir; sets "format short g".

%% Validate options before any expensive work
% Bin widths for the two columns of X. Column 2 (pre-treatment consumption)
% always uses a width of 10.
switch covariate
    case 'income'
        dx1 = 5;    % step size is 5000 for income
    case 'size'
        dx1 = 100;
    case {'vintage','min','max','std'}
        dx1 = 10;   % min/max/std: step size 10 for that baseline measure
    otherwise
        error('generate_results_quadrant_rule:badCovariate', ...
            'Unknown covariate "%s".', covariate);
end
dx2 = 10;

% Cost label for file names: SMC/PMC for cost savings, kwh for kWh
% reduction. The bootstrap folder uses the lower-case form.
if tcost > 0
    if SMC
        s_cost = 'SMC';
    else
        s_cost = 'PMC';
    end
elseif tcost == 0
    s_cost = 'kwh';
else
    error('generate_results_quadrant_rule:badTcost', ...
        'tcost must be >= 0 (>0 for cost savings, 0 for kWh reduction).');
end
s_speci = lower(s_cost);

switch winter_prevuse
    case 1
        s_winter = '_winter';
    case 0
        s_winter = '';
    otherwise
        error('generate_results_quadrant_rule:badWinterPrevuse', ...
            'winter_prevuse must be 0 or 1.');
end

%% Gather inputs
[X,Y,D,n,ps] = generate_inputs_quadrant_rule(sample,wave,tcost,covariate,SMC);

%% Estimate g using IPW
g = Y.*(D./ps - (1-D)./(1-ps)); % elements of W(G) - W(G_0)

%% Compress data by histogram
% Xu: lower edges of each bin; gu: sum of g within the bin; nw: households
% in the bin. xx1/xx2 are per-household bin indices in sortrows(X) order,
% and Ind is that sort permutation.
[Xu,gu,nw,boo1,boo2,xx1,xx2,Ind] = local_aggregate_bins(X,g,dx1,dx2);
x1 = Xu(:,1);
x2 = Xu(:,2);

%% Minimize cost by grid search over quadrant rules
tic1 = tic; % time grid search

grid_x1 = boo1'; grid_x2 = boo2';
k1 = numel(grid_x1); k2 = numel(grid_x2);
m_W = inf(k1,k2);   % best (lowest) W over the 4 sign choices at each (i1,i2)
minW = Inf;         % best W overall
min_i1 = []; min_i2 = []; min_sign1 = []; min_sign2 = [];

for i1 = 1:k1
    d1 = x1 - grid_x1(i1);
    for i2 = 1:k2
        d2 = x2 - grid_x2(i2);
        for sign1 = [-1 1]
            for sign2 = [-1 1]
                % treat if both threshold conditions are met
                select = ((sign1.*d1)>=0).*((sign2.*d2)>=0);
                W = sum(gu.*select); % total savings (divide by n for /hh)
                if W < m_W(i1,i2)
                    m_W(i1,i2) = W;
                end
                % strict "<" keeps the first minimizer in loop order
                if W < minW
                    min_i1 = i1;
                    min_i2 = i2;
                    min_sign1 = sign1;
                    min_sign2 = sign2;
                    minW = W;
                end
            end
        end
    end
end

% Without this check a NaN/Inf-only objective would fail later with an
% "undefined variable" error.
if isempty(min_i1)
    error('generate_results_quadrant_rule:noRule', ...
        'Grid search found no quadrant rule with W < Inf (check g for NaN/Inf).');
end

fprintf('Grid search takes %.2f sec\n',toc(tic1))

%% Optimal rule: bin-level and household-level treatment assignment
% treatment status for each bin of Xu
in_Ghat = (min_sign1*(x1 - grid_x1(min_i1))>=0).*(min_sign2*(x2 - grid_x2(min_i2))>=0);

% Map the bin-level assignment back to households. Bin (i,j) sits at row
% (i-1)*numel(boo2)+j of Xu. Rows of in_Ghat_hh follow sortrows(X) order
% (X(Ind,:)). Households with bin index 0 (e.g. NaN covariates) stay NaN.
in_Ghat_hh = nan(size(g));
hh_binned = (xx1 > 0) & (xx2 > 0);
in_Ghat_hh(hh_binned) = in_Ghat((xx1(hh_binned)-1)*numel(boo2) + xx2(hh_binned));

% grid indices
v_x1 = grid_x1;
v_x2 = grid_x2;

%% Calculate confidence interval (bootstrap)
rng(0);

if bsimax > 0
    vhats_p = zeros(bsimax,1);
    vhats_n = zeros(bsimax,1);
    scaled_ATT_itr = zeros(bsimax,1);
    ATE_itr = zeros(bsimax,1);
    W_bs_Ghat_itr = zeros(bsimax,1);

    % shared name for each bootstrap batch, and its output folder
    s_now_batch = datestr(now(),'yyyymmdd_HHMMSS');
    fpath_BS = fullfile('.','intermediate_data','EWM_results','BS','quadrant', ...
        sprintf('%s_%s_%s',s_now_batch,covariate,s_speci));
    if ~exist(fpath_BS,'dir')
        mkdir(fpath_BS);
    end

    t_boot = tic;
    dt_start = now();
    for bsi = 1:bsimax
        t_iter = tic;
        % DRAW BOOTSTRAP INDEXES
        bsperm = randi(n,[n 1]);
        % Stack resampled g and original -g so each bin sum is Vbs - Vn
        % (V. = W.(G) - W.(no treatment)); baseline welfare differences out.
        [bXu,bgu] = local_aggregate_bins([X(bsperm,:); X],[g(bsperm,:); -g],dx1,dx2);
        % Use the bootstrap's own bin coordinates so the masks always
        % align with bgu, even if its bin edges differ from the original.
        bx1 = bXu(:,1);
        bx2 = bXu(:,2);

        % Maximize and minimize Wbs - Wn over all quadrant rules on the grid
        maxWb = -Inf; minWb = Inf;
        for i1 = 1:k1
            d1 = bx1 - grid_x1(i1);
            for i2 = 1:k2
                d2 = bx2 - grid_x2(i2);
                for sign1 = [-1 1]
                    for sign2 = [-1 1]
                        select = ((sign1.*d1)>=0).*((sign2.*d2)>=0);
                        Wb = sum(bgu.*select);
                        maxWb = max(maxWb,Wb);
                        minWb = min(minWb,Wb);
                    end
                end
            end
        end
        vhats_p(bsi) = maxWb;
        vhats_n(bsi) = minWb;
        fprintf('Iteration %d completed after %0.2g minutes\n', bsi, toc(t_boot)/60);
        dt_cmpt = now();

        % scaled ATT and ATE for this bootstrap sample
        scaled_ATT_bs = (sum(D(bsperm,:))/n).*...
            (sum( D(bsperm,:).*Y(bsperm,:) - ((1-D(bsperm,:)).*Y(bsperm,:).*ps(bsperm,:))./(1-ps(bsperm,:)) )/sum(D(bsperm,:)));
        ATE_bs = mean(g(bsperm,:)); % equiv. to ATE = mean(D.*Y./ps)-mean(((1-D).*Y)./(1-ps));

        % Apply G_hat to the bootstrap bins. bgu includes -g of the original
        % sample, so this is (W_bs(Ghat) - W_n(Ghat))/n.
        % NOTE(review): confirm downstream expects this difference rather
        % than the bootstrap level.
        in_Ghat_on_bs = (min_sign1*(bx1 - grid_x1(min_i1))>=0).*(min_sign2*(bx2 - grid_x2(min_i2))>=0);
        W_bs_Ghat = sum(in_Ghat_on_bs.*bgu)/n;

        scaled_ATT_itr(bsi,1) = scaled_ATT_bs;
        ATE_itr(bsi,1) = ATE_bs;
        W_bs_Ghat_itr(bsi,1) = W_bs_Ghat;

        % save each BS result (Wb is the value of the last rule evaluated)
        dt_outp = now();
        dt = [dt_start,dt_cmpt,dt_outp];
        save(fullfile(fpath_BS,sprintf('%s_BS_%d.mat',s_now_batch,bsi)),...
            'bsi','s_now_batch','Wb','maxWb','minWb','scaled_ATT_bs','ATE_bs','W_bs_Ghat',...
            'bsperm','dt','covariate');
        % progress marker (time for this draw only)
        fprintf('Boot #%d took %2.3f sec\n',bsi,toc(t_iter))
    end
else
    % No bootstrap: placeholders so the CI, save, and table code still run
    vhats_n = 0;
    vhats_p = 0;
    scaled_ATT_itr = zeros(0,1);
    ATE_itr = zeros(0,1);
    W_bs_Ghat_itr = zeros(0,1);
end

%% two-sided 95% confidence interval for the welfare gain
% Following K&T's appendix B and replication code, v* = max(vhats_p, -vhats_n).
% vhats_p solves the maximization (positive), vhats_n the minimization
% (negative), so this is the larger absolute deviation.
vhats_abs = max([vhats_p -vhats_n],[],2);

%% Export rule parameters for graphs
if fullsample==0
    filename_coefs = sprintf('coef_quadrant_%s_%s%s_wave%1.0f.mat',covariate,s_cost,s_winter,wave);
else
    filename_coefs = sprintf('coef_quadrant_fullsample_%s_%s%s_wave%1.0f.mat',covariate,s_cost,s_winter,wave);
end

filename = sprintf('%s_%s',fbase,filename_coefs);
fpathf = fullfile(savedir,filename);
fprintf('Saving "%s" to "%s" ... ',filename,savedir);
save(fpathf,'min_i1','min_i2','min_sign1','min_sign2','m_W',...
   'in_Ghat','v_x1','v_x2','Xu','nw','gu','n','covariate','vhats_n','vhats_p','minW','scaled_ATT_itr','ATE_itr','W_bs_Ghat_itr',...
    'in_Ghat_hh','boo1','boo2','xx1','xx2','Ind');
fprintf('   Done!\n');

%% Output for table
percent = sum(nw.*in_Ghat)/n;   % share of households treated
savings = sum(gu.*in_Ghat)/n;   % per-household point estimate
if tcost > 0
    total_savings = savings*n*12;
else
    total_savings = savings*n*12/1000;
end

ci_lb = (minW-prctile(vhats_abs,95))/n;
ci_ub = (minW+prctile(vhats_abs,95))/n;

savings_EWM_quadrant = nan(1,4);
savings_EWM_quadrant = array2table(savings_EWM_quadrant,'VariableNames',{'Covariates','Share\ treated','Net\ cost\ changes','Total\ cost\ changes'});
format short g
savings_EWM_quadrant.(1) = {covariate};
savings_EWM_quadrant.(2) = round(percent*100);
savings_EWM_quadrant.(3) = round(savings,2);
savings_EWM_quadrant.(4) = round(total_savings,-3); % nearest 1000 (was roundn(x,3))

ci_lb = round(ci_lb,2);
ci_ub = round(ci_ub,2);

ci_table = nan(1,4);
ci_table = array2table(ci_table,'VariableNames',{'Covariates','Share\ treated','Net\ cost\ changes','Total\ cost\ changes'});
cilb_string = string(ci_lb);
ciub_string = string(ci_ub);
ci_for_table = strcat('(',cilb_string,',',ciub_string,')');
ci_table.(1) = "";
ci_table.(2) = "";
ci_table.(3) = ci_for_table;
ci_table.(4) = "";
savings_EWM_quadrant = [savings_EWM_quadrant;ci_table];

end

function [Xu,gu,nw,edges1,edges2,bin1,bin2,ord] = local_aggregate_bins(X,g,dx1,dx2)
% Sort the rows of X (carrying g along) and assign each row to a 2-D
% histogram bin of width [dx1,dx2]. Then aggregate g and row counts per bin.
% Bins are enumerated with the column-2 edge varying fastest. Xu holds each
% bin's lower edges, including the trailing edges whose bins are always
% empty (gu = 0, nw = 0). bin1/bin2 are per-row bin indices in sorted order,
% and ord is the sortrows permutation.
[Xs,ord] = sortrows(X);
gs = g(ord);
[~,edges1,edges2,bin1,bin2] = histcounts2(Xs(:,1),Xs(:,2),...
    'BinWidth',[dx1,dx2],'Normalization','count');
n1 = numel(edges1);
n2 = numel(edges2);
nu = n1*n2;
Xu = nan(nu,2); gu = nan(nu,1); nw = nan(nu,1);
k = 1;
for i = 1:n1
    for j = 1:n2
        in_bin = (bin1==i & bin2==j);
        Xu(k,:) = [edges1(i), edges2(j)];
        gu(k,1) = sum(gs(in_bin));
        nw(k,1) = sum(in_bin);
        k = k+1;
    end
end
end
